function [M,A4,A3B1,A2B2,AB2,B4] = fourth_moment(x,p)
% FOURTH_MOMENT  Fourth moment of the stochastic (mini-batch) gradient error.
%
% Let B be the mean of all N columns of x and let A be the mean of a batch
% of p distinct columns. Every one of the nchoosek(N,p) batches (sampling
% without replacement) is enumerated and weighted equally, so the outputs
% are exact expectations over that uniform distribution, not Monte-Carlo
% estimates.
%
% Input:
%   x [d,N] N column vectors of dimension d
%   p [1]   batch size, an integer with 1 <= p <= N
% Output:
%   M    [1] E|A-B|^4
%   A4   [1] E|A|^4
%   A3B1 [1] E|A|^2(A.B)
%   A2B2 [1] E|A|^2|B|^2
%   AB2  [1] E(A.B)^2
%   B4   [1] |B|^4
%
% Errors:
%   fourth_moment:invalidBatchSize  if p is not an integer in [1, N].
%
% Note: the full list of combinations is materialised, so both time and
% memory grow like nchoosek(N,p); keep N and p small.

N = size(x,2); % number of vectors

% Reject batch sizes for which no valid batch exists. Without this check
% p > N gives zero combinations (0/0 = NaN outputs) and p = 0 gives an
% empty batch whose mean is NaN.
if ~(isnumeric(p) && isscalar(p) && isreal(p) && p == floor(p) ...
        && p >= 1 && p <= N)
    error('fourth_moment:invalidBatchSize', ...
        'p must be an integer batch size with 1 <= p <= N (N = %d).', N);
end

% Enumerate all batches: one row of column indices per batch.
if N == 1
    % nchoosek(1:1,p) would be a scalar call, which nchoosek interprets as
    % the binomial coefficient rather than a set to choose from. The only
    % valid batch here is the single vector itself.
    all_comb = 1;
else
    all_comb = nchoosek(1:N, p); % [num_comb,p]
end
num_comb = size(all_comb, 1); % number of batches

B = mean(x,2);  % [d,1] mean of all vectors
B2 = dot(B,B);  % |B|^2, loop invariant

% ---- accumulate the different expectations ---- %
M_sum = 0;    % sum of |A-B|^4
A4_sum = 0;   % sum of |A|^4
A3B1_sum = 0; % sum of |A|^2(A.B)
A2B2_sum = 0; % sum of |A|^2|B|^2
AB2_sum = 0;  % sum of (A.B)^2
for i = 1:num_comb
    index = all_comb(i,:);  % [p] column indices of this batch
    A = mean(x(:,index),2); % [d,1] mean of the vectors in the batch
    D = A - B;              % [d,1] batch error
    A2 = dot(A,A);          % |A|^2
    AB = dot(A,B);          % A.B
    M_sum = M_sum + dot(D,D)^2;
    A4_sum = A4_sum + A2^2;
    A3B1_sum = A3B1_sum + A2*AB;
    A2B2_sum = A2B2_sum + A2*B2;
    AB2_sum = AB2_sum + AB^2;
end

% Uniform average over all batches.
M = M_sum / num_comb;
A4 = A4_sum / num_comb;
A3B1 = A3B1_sum / num_comb;
A2B2 = A2B2_sum / num_comb;
AB2 = AB2_sum / num_comb;
B4 = B2^2;
end